import cv2
import numpy as np
import os
from os.path import exists
from imutils import paths
import pickle
from tqdm import tqdm
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing import image

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_size(file):
    """
    Return the size of a file in megabytes (MB).

    Args:
        file (str): Path to the file whose size is measured.

    Returns:
        float: The file size in MB, where 1 MB = 1024 * 1024 bytes.
    """
    bytes_per_mb = 1024 ** 2
    size_in_bytes = os.stat(file).st_size
    return size_in_bytes / bytes_per_mb

def _label_from_path(image_path):
    """Return 1 if the file name's prefix before the first '_' is 'dog', else 0."""
    prefix = os.path.basename(image_path).split('_')[0]
    return 1 if prefix == 'dog' else 0


def _load_image(image_path, method):
    """Load one image as the raw per-image input for ``method``.

    'vgg': array from Keras' loader, resized to 224x224 (not yet preprocessed).
    'flat': grayscale image resized to 32x32 and flattened to a 1-D vector.

    Raises:
        ValueError: if OpenCV cannot read the file ('flat' method).
    """
    if method == 'vgg':
        img = image.load_img(image_path, target_size=(224, 224))
        return image.img_to_array(img)
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with the offending path instead of a cryptic resize error.
    if img is None:
        raise ValueError(f"Failed to read image: {image_path}")
    return cv2.resize(img, (32, 32)).flatten()


def createXY(train_folder, dest_folder, method='vgg', batch_size=64):
    """
    Build (or load from cache) the feature list X and label list y.

    If ``dest_folder/X.pkl`` and ``dest_folder/y.pkl`` both exist, they are
    unpickled and returned as-is. Note that the cache is keyed only by
    ``dest_folder``: ``train_folder`` and ``method`` are not checked, so use a
    distinct ``dest_folder`` per method/dataset.

    Otherwise every image under ``train_folder`` is processed in batches:
    - 'vgg': features from VGG16 (ImageNet weights, no top, max pooling).
    - 'flat': the raw 32x32 grayscale pixels, flattened.
    The results are then pickled into ``dest_folder``, which is created if
    needed.

    Args:
        train_folder (str): Folder scanned for images.
        dest_folder (str): Folder holding the X.pkl / y.pkl cache.
        method (str): 'vgg' or 'flat'.
        batch_size (int): Number of images processed per batch (>= 1).

    Returns:
        tuple: (X, y). X is a list of per-image feature vectors. y is a list
        of int labels: 1 when the file name starts with 'dog_', else 0.

    Raises:
        ValueError: on an unsupported ``method``, a ``batch_size`` < 1, or
            an unreadable image.
    """
    if method not in ('vgg', 'flat'):
        raise ValueError(f"Unsupported method {method!r}; expected 'vgg' or 'flat'")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    x_file_path = os.path.join(dest_folder, "X.pkl")
    y_file_path = os.path.join(dest_folder, "y.pkl")

    # Reuse the cached X and y when both files are present.
    if os.path.exists(x_file_path) and os.path.exists(y_file_path):
        logging.info("X和y已经存在，直接读取")
        logging.info(f"X文件大小:{get_size(x_file_path):.2f}MB")
        logging.info(f"y文件大小:{get_size(y_file_path):.2f}MB")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # caches written by this program, never files from untrusted sources.
        with open(x_file_path, 'rb') as f:
            X = pickle.load(f)
        with open(y_file_path, 'rb') as f:
            y = pickle.load(f)
        return X, y

    logging.info("读取所有图像，生成X和y")
    image_paths = list(paths.list_images(train_folder))

    model = None
    if method == 'vgg':
        model = VGG16(weights='imagenet', include_top=False, pooling="max")
        logging.info("完成构建 VGG16 模型")

    X = []
    y = []
    # One iteration per batch; the final batch may be shorter than batch_size.
    for start in tqdm(range(0, len(image_paths), batch_size), desc="读取图像"):
        batch_paths = image_paths[start:start + batch_size]
        batch_images = np.array([_load_image(p, method) for p in batch_paths])
        if method == 'vgg':
            batch_features = model.predict(preprocess_input(batch_images), verbose=0)
        else:
            # For 'flat' the flattened pixels already are the features.
            batch_features = batch_images
        X.extend(batch_features)
        y.extend(_label_from_path(p) for p in batch_paths)

    logging.info(f"X.shape: {np.shape(X)}")
    logging.info(f"y.shape: {np.shape(y)}")

    # Create the cache folder if missing. An empty dest_folder means the
    # current directory, which os.makedirs would reject.
    if dest_folder:
        os.makedirs(dest_folder, exist_ok=True)
    with open(x_file_path, 'wb') as f:
        pickle.dump(X, f)
    with open(y_file_path, 'wb') as f:
        pickle.dump(y, f)

    return X, y
